import matplotlib.pyplot as plt
import numpy as np


# Tableau-20 palette. The literals are 8-bit RGB triples; each channel is
# divided by 255 so the entries are float triples in [0, 1].
tableau20 = [
    (red / 255., green / 255., blue / 255.)
    for red, green, blue in [
        (31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
        (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
        (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
        (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),
        (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229),
    ]
]

# Right-hand x-axis limit used by the plot_* functions.
T = 8000
# Parameter of the two coefficients computed in get_bound.
eta = 2.0

def get_value(value_str):
	"""Return the number stored in a ``label:number`` field.

	The field is split on ``:`` and must yield exactly two parts (otherwise
	the unpacking raises ``ValueError``); the second part is parsed with
	``float``.
	"""
	label, raw_number = value_str.split(':')
	return float(raw_number)


def get_bound(bound_1, bound_2):
	"""Return ``eta * C * bound_1 + C * bound_2`` with ``C = 1 / (1 - exp(-eta))``.

	``eta`` is the module-level constant.
	"""
	one_minus_decay = 1. - np.exp(-eta)
	weight_1 = eta / one_minus_decay
	weight_2 = 1. / one_minus_decay
	return weight_1 * bound_1 + weight_2 * bound_2


def get_info(fname, T):
	"""Parse one run log into per-step series.

	The first two lines of the file and the last element produced by
	splitting on newlines are skipped. Every remaining line is split on
	``', '`` into ``label:number`` fields, read by position:
	0 -> step, 1 -> train accuracy, 2 -> test accuracy, 3 -> 'l2_grads',
	4 -> SI error, 5 -> second term passed to ``get_bound``.

	Args:
		fname: path of the log file.
		T: parsing stops at the first line whose step is ``>= T``.

	Returns:
		dict of equally long lists under the keys 't', 'tr_acc', 'te_acc',
		'l2_grads', 'SI_err' and 'bound'.
	"""
	# The context manager closes the handle; the previous bare open() left
	# one file open per parsed log.
	with open(fname, 'r') as log_file:
		logs = log_file.read().split('\n')[2: -1]

	ts, tr_accs, te_accs, l2_grads, SI_errs, bounds = [], [], [], [], [], []
	for line in logs:
		items = line.split(', ')
		t = get_value(items[0])
		if t >= T:
			break
		ts.append(t)
		tr_accs.append(get_value(items[1]))
		te_accs.append(get_value(items[2]))
		l2_grads.append(get_value(items[3]))
		SI_err = get_value(items[4])
		bound_2 = get_value(items[5])
		SI_errs.append(SI_err)
		bounds.append(get_bound(SI_err, bound_2))

	return {'t' : ts, 'tr_acc': tr_accs, 'te_acc' : te_accs,
			'l2_grads' : l2_grads, 'SI_err' : SI_errs, 'bound' : bounds}


def get_infos(L, R, T, fdir):
	"""Aggregate the logs ``{fdir}/{i}.out`` for ``i`` in ``range(L, R)``.

	Each log is parsed with ``get_info(fname, T)``. For every series the runs
	are stacked into one array and reduced along axis 0.

	Returns:
		dict mapping each of 't', 'tr_acc', 'te_acc', 'l2_grads', 'SI_err'
		and 'bound' to a ``(mean, std)`` pair.
	"""
	series_names = ('t', 'tr_acc', 'te_acc', 'l2_grads', 'SI_err', 'bound')
	per_run = {name: [] for name in series_names}
	for run_id in range(L, R):
		run = get_info(f'{fdir}/{run_id}.out', T)
		for name, series in run.items():
			per_run[name].append(series)

	summary = {}
	for name, runs in per_run.items():
		stacked = np.array(runs)
		summary[name] = (stacked.mean(axis=0), stacked.std(axis=0))
	return summary


def plot_errors(infos):
	"""Show the mean train, SI and test error curves against the step axis.

	Args:
		infos: mapping of series name to ``(mean, std)`` as returned by
			``get_infos``; only the means are used.
	"""
	steps = infos['t'][0]
	# (series, colour) in drawing order; the legend labels below follow it.
	curves = [
		(1. - infos['tr_acc'][0], tableau20[5]),
		(infos['SI_err'][0], tableau20[3]),
		(1. - infos['te_acc'][0], tableau20[9]),
	]
	labels = [
		'train error $\\mathcal{R}(S)$',
		'SI error $\\mathcal{R}(S_I)$',
		'test error $\\mathcal{R}(\\mathcal{D})$',
	]

	plt.figure()
	for series, colour in curves:
		plt.plot(steps, series, linewidth=2.0, color=colour)

	plt.grid(axis='both', linestyle=':')
	plt.xlim(0, T)
	plt.ylim(0, 1)
	plt.xlabel('steps', fontsize=20)
	plt.tick_params(labelsize=20)
	plt.legend(labels, fontsize=10)
	plt.show()
	plt.close()


def plot_test_bound(infos, x_min=0, x_max=100, err_bar=False):
	"""Plot the mean bound and the mean test error over a window of log entries.

	Args:
		infos: mapping of series name to ``(mean, std)`` as returned by
			``get_infos``.
		x_min, x_max: slice bounds (indices into the series, not step values)
			selecting the window ``[x_min:x_max]``.
		err_bar: when true, shade mean +/- std around both curves.

	Also prints the bound and the train / SI / test accuracies at the last
	entry of the window.
	"""
	window = slice(x_min, x_max)
	x = infos['t'][0][window]
	tr_acc = infos['tr_acc'][0][window]
	SI_err = infos['SI_err'][0][window]
	te_err_mean = 1. - infos['te_acc'][0][window]
	te_bound_mean = infos['bound'][0][window]
	# Every printed value is taken at the end of the window. SI_acc used to be
	# read from the end of the full series, which disagrees with the other
	# three whenever the window stops early.
	print('bound:', te_bound_mean[-1], 'tr_acc:', tr_acc[-1], 'SI_acc:', 1. - SI_err[-1], 'te_acc:', 1. - te_err_mean[-1])
	plt.figure()
	plt.plot(x, te_bound_mean, linewidth=2.0, color=tableau20[6])
	plt.plot(x, te_err_mean, linewidth=2.0, color=tableau20[9])

	if err_bar:
		te_bound_std = infos['bound'][1][window]
		# std of (1 - accuracy) equals std of accuracy, so the accuracy std
		# is used directly for the error band.
		te_err_std = infos['te_acc'][1][window]
		plt.fill_between(x, te_bound_mean - te_bound_std, te_bound_mean + te_bound_std, color=tableau20[6], alpha=0.1)
		plt.fill_between(x, te_err_mean - te_err_std, te_err_mean + te_err_std, color=tableau20[9], alpha=0.1)

	plt.xlabel('steps', fontsize=20)
	plt.tick_params(labelsize=20)
	plt.legend(['our bound', 'test error'], fontsize=10)
	plt.grid(axis='both', linestyle=':')
	plt.xlim(x[0], T)
	# Only the lower y-limit is pinned; the upper one is left automatic.
	plt.ylim(0,)
	plt.show()
	plt.close()


def plot_bound_detail(infos):
	"""Plot the mean bound with its two additive terms shaded, plus the test error.

	The area between 0 and ``c1 * SI_err`` is the first term of ``get_bound``;
	the area between that curve and the full bound is the second term.

	Args:
		infos: mapping of series name to ``(mean, std)`` as returned by
			``get_infos``; only the means are used.
	"""
	x = infos['t'][0]
	SI_err = infos['SI_err'][0]
	# Use the module-level eta, i.e. the same value get_bound used when
	# infos['bound'] was built. A local `eta = 1.` used to shadow it here, so
	# the shaded first term did not match the plotted bound.
	c1 = eta / (1. - np.exp(-eta))

	bound1 = c1 * SI_err

	bound = infos['bound'][0]
	te_err = 1 - infos['te_acc'][0]
	plt.figure()
	plt.plot(x, bound, linewidth=2.0, color=tableau20[6])
	plt.plot(x, te_err, linewidth=2.0, color=tableau20[9])

	plt.fill_between(x, 0., bound1, color=tableau20[6], alpha=0.15)
	plt.fill_between(x, bound1, bound, color=tableau20[4], alpha=0.15)
	plt.xlabel('steps', fontsize=20)
	plt.tick_params(labelsize=20)
	plt.legend(['our bound', 'test error', '$\\eta C_{\\eta}\\mathcal{R}(S_I)$', '$C_{\\eta}\\mathrm{bound}_2$'], fontsize=10)
	plt.grid(axis='both', linestyle=':')
	plt.xlim(0, T)
	plt.ylim(0, 1)
	plt.show()
	plt.close()


def plot_sgd_error(infos):
	"""Show the mean train and test error curves against the step axis.

	Args:
		infos: mapping of series name to ``(mean, std)`` as returned by
			``get_infos``; only the means are used.
	"""
	steps = infos['t'][0]
	# (series, colour) in drawing order; the legend labels below follow it.
	curves = (
		(1. - infos['tr_acc'][0], tableau20[5]),
		(1. - infos['te_acc'][0], tableau20[9]),
	)
	labels = [
		'train error $\\mathcal{R}(S)$',
		'test error $\\mathcal{R}(\\mathcal{D})$',
	]

	plt.figure()
	for series, colour in curves:
		plt.plot(steps, series, linewidth=2.0, color=colour)

	plt.grid(axis='both', linestyle=':')
	plt.xlim(0, T)
	plt.ylim(0, 1)
	plt.xlabel('steps', fontsize=20)
	plt.tick_params(labelsize=20)
	plt.legend(labels, fontsize=20)
	plt.show()
	plt.close()


def get_grad_diff(infos):
	"""Return the mean 'l2_grads' series divided by the squared learning rate.

	The learning rate starts at 0.001 and is multiplied by 0.9 after every
	entry whose step ``t`` is a positive multiple of 200; the decay applies
	from the following entry on.

	Args:
		infos: mapping of series name to ``(mean, std)`` as returned by
			``get_infos``. It is not modified.

	Returns:
		A new float array with one rescaled value per entry of ``infos['t'][0]``.
	"""
	x = infos['t'][0]
	# Work on a copy. Dividing infos['l2_grads'][0] in place changed the
	# caller's data, so a second call on the same infos rescaled it again.
	grad_diffs = np.array(infos['l2_grads'][0], dtype=float)
	lr = 0.001
	for i, t in enumerate(x):
		grad_diffs[i] /= lr ** 2.
		if t % 200 == 0 and t > 0:
			lr = lr * 0.9
	return grad_diffs

def plot_grad_diff(infos):
	"""Plot the square root of ``get_grad_diff(infos)`` against the step axis.

	Args:
		infos: mapping of series name to ``(mean, std)`` as returned by
			``get_infos``.
	"""
	steps = infos['t'][0]
	norms = np.sqrt(get_grad_diff(infos))

	plt.figure()
	plt.plot(steps, norms, linewidth=2.0, color=tableau20[9])
	plt.grid(axis='both', linestyle=':')
	plt.xlim(0, T)
	plt.xlabel('steps', fontsize=20)
	plt.tick_params(labelsize=20)
	plt.legend(['gradient difference norm'], fontsize=10)
	plt.show()
	plt.close()

def plot_mgrad(infos):
	"""Plot the mean of ``infos['sum_grad2']`` against ``infos['m']`` with a +/- std band.

	Args:
		infos: mapping with ``(mean, std)`` pairs under 'm' and 'sum_grad2',
			as returned by ``get_mgrad``.
	"""
	m_values = infos['m'][0]
	centre, spread = infos['sum_grad2']
	lower = centre - spread
	upper = centre + spread

	plt.figure()
	plt.plot(m_values, centre, linewidth=3.0, color=tableau20[1])
	plt.fill_between(m_values, lower, upper, color=tableau20[5], alpha=0.2)

	plt.xlabel('m = |J|', fontsize=20)
	plt.tick_params(labelsize=20)
	plt.legend(['$\\sum_t || g_t ||^2$'], fontsize=18.5)
	plt.grid(axis='both', linestyle=':')
	plt.show()
	plt.close()

def plot_fsgd_sgd_detail(fgd_infos, gd_infos, x_min, err_bar=True):
	"""Compare train/test error of two sets of runs from entry ``x_min`` onwards.

	Args:
		fgd_infos: ``get_infos`` result drawn with the 'FSGD' legend labels;
			its 't' series provides the x axis.
		gd_infos: ``get_infos`` result drawn with the 'SGD' legend labels.
		x_min: index of the first log entry to draw.
		err_bar: when true, shade mean +/- std around every curve.
	"""
	tail = slice(x_min, None)
	steps = fgd_infos['t'][0][tail]

	def error_mean(run, key):
		return 1. - run[key][0][tail]

	def accuracy_std(run, key):
		return run[key][1][tail]

	fgd_te = (fgd_infos, 'te_acc', tableau20[9])
	gd_te = (gd_infos, 'te_acc', tableau20[10])
	fgd_tr = (fgd_infos, 'tr_acc', tableau20[5])
	gd_tr = (gd_infos, 'tr_acc', tableau20[6])

	plt.figure()
	# Line order matches the legend labels below.
	for run, key, colour in (fgd_te, gd_te, fgd_tr, gd_tr):
		plt.plot(steps, error_mean(run, key), linewidth=2.0, color=colour)

	if err_bar:
		for run, key, colour in (fgd_tr, fgd_te, gd_tr, gd_te):
			centre = error_mean(run, key)
			spread = accuracy_std(run, key)
			plt.fill_between(steps, centre - spread, centre + spread, color=colour, alpha=0.15)

	plt.legend(['FSGD test error', 'SGD test error', 'FSGD train error', 'SGD train error'], fontsize=10)
	plt.grid(axis='both', linestyle=':')
	plt.xlabel('steps', fontsize=20)
	plt.tick_params(labelsize=20)
	plt.show()
	plt.close()


def _plot_random_label_figure(x, curves, ylabel, err_bar=False, ylim=None):
	"""Show one figure with a line, and optionally a mean +/- std band, per curve.

	Args:
		x: shared x values.
		curves: list of ``(label, colour, mean, std)``; ``std`` is only read
			when ``err_bar`` is true.
		ylabel: y-axis label.
		err_bar: when true, shade mean +/- std around each line.
		ylim: optional ``(bottom, top)`` pair for the y axis.
	"""
	plt.figure()
	# All lines are drawn before any band, as in the original layout, so the
	# positional legend labels line up with the lines.
	for _, colour, mean, _ in curves:
		plt.plot(x, mean, linewidth=1.0, color=colour)

	if err_bar:
		for _, colour, mean, std in curves:
			plt.fill_between(x, mean - std, mean + std, color=colour, alpha=0.1)

	plt.xlabel('steps', fontsize=20)
	plt.ylabel(ylabel, fontsize=20)
	if ylim is not None:
		plt.ylim(*ylim)
	plt.tick_params(labelsize=20)
	plt.legend([label for label, _, _, _ in curves], loc='upper left', fontsize=10)
	plt.grid(axis='both', linestyle=':')
	plt.show()
	plt.close()


def random_label(err_bar=True):
	"""Show four figures comparing the runs labelled p=0, 0.1, 0.2 and 0.5.

	The figures are training accuracy, test error, the bound, and the
	gradient difference norm. The first three get mean +/- std bands when
	``err_bar`` is true.

	Changes from the previous version:
	  * the p=0.5 bound band used the p=0.2 std for its lower edge;
	  * p=0.5 shared its colour with p=0.2 and now has its own;
	  * the gradient figure's legend listed a p=0.5 entry although only three
	    curves are drawn there; the legend now matches the drawn curves.
	"""
	# NOTE(review): the p=0 runs are read from "log/cifar10/fgd" while the
	# script entry point reads "log/cifar10/fsgd" -- confirm the directory.
	runs = [
		('p=0', tableau20[2], get_infos(0, 100, 8000, fdir="log/cifar10/fgd")),
		('p=0.1', tableau20[4], get_infos(0, 100, 8000, fdir="log/cifar10/random_label/0.1")),
		('p=0.2', tableau20[6], get_infos(0, 100, 8000, fdir="log/cifar10/random_label/0.2")),
		('p=0.5', tableau20[8], get_infos(0, 100, 8000, fdir="log/cifar10/random_label/0.5")),
	]
	x = runs[0][2]['t'][0]

	tr_curves = [(label, colour, info["tr_acc"][0], info["tr_acc"][1])
				 for label, colour, info in runs]
	_plot_random_label_figure(x, tr_curves, 'training accuracy', err_bar)

	# std of (1 - accuracy) equals std of accuracy.
	te_curves = [(label, colour, 1 - info["te_acc"][0], info["te_acc"][1])
				 for label, colour, info in runs]
	_plot_random_label_figure(x, te_curves, 'test error', err_bar, ylim=(0, 1))

	bound_curves = [(label, colour, info["bound"][0], info["bound"][1])
					for label, colour, info in runs]
	_plot_random_label_figure(x, bound_curves, 'our bound', err_bar, ylim=(0, 2))

	# As before, only the first three runs are drawn in this figure, and
	# without bands.
	grad_curves = [(label, colour, np.sqrt(get_grad_diff(info)), None)
				   for label, colour, info in runs[:3]]
	_plot_random_label_figure(x, grad_curves, 'gradient difference norm')

def get_mgrad(k, task=""):
	"""Aggregate the logs ``log/{task}/mgrad/{i}.out`` for ``i`` in ``range(k)``.

	Every line (except the last element produced by splitting on newlines) is
	split on ``', '``; field 0 is read as ``m`` and field 1 as ``sum_grad2``.

	Args:
		k: number of log files to read.
		task: sub-directory of ``log/`` that holds the ``mgrad`` folder.

	Returns:
		dict mapping 'm' and 'sum_grad2' to ``(mean, std)`` pairs computed
		across the files (axis 0).
	"""
	infos = {'m' : [], 'sum_grad2': []}
	for i in range(k):
		fname = f'log/{task}/mgrad/{i}.out'
		# The context manager closes the handle; the previous bare open()
		# left one file open per log.
		with open(fname, 'r') as log_file:
			logs = log_file.read().split('\n')[:-1]
		ms, sum_grad2s = [], []
		for line in logs:
			items = line.split(', ')
			ms.append(get_value(items[0]))
			sum_grad2s.append(get_value(items[1]))
		infos['m'].append(ms)
		infos['sum_grad2'].append(sum_grad2s)

	return {key: (np.mean(runs, axis=0), np.std(runs, axis=0))
			for key, runs in infos.items()}

if __name__ == '__main__':
    # Number of runs (log files) aggregated per experiment.
    n = 100

    # Runs logged under log/cifar10/fsgd.
    infos = get_infos(0, n, 8000, fdir='log/cifar10/fsgd')
    plot_errors(infos)
    plot_test_bound(infos, 0, 160, err_bar=True)
    plot_test_bound(infos, 80, 160, err_bar=True)
    plot_bound_detail(infos)
    plot_grad_diff(infos)

    mginfos = get_mgrad(n, task="cifar10")
    plot_mgrad(mginfos)

    # Runs logged under log/cifar10/sgd. The SGD error plot must use these;
    # it was previously handed the fsgd `infos` by mistake.
    sgdinfos = get_infos(0, n, 8000, fdir='log/cifar10/sgd')
    plot_sgd_error(sgdinfos)
    plot_fsgd_sgd_detail(infos, sgdinfos, 50)

    # random_label's only parameter is the err_bar flag, not a run count;
    # the old call random_label(n) just passed a truthy value.
    random_label(err_bar=True)